#include "stdafx.h"	
#ifdef BZSLIB_WINKERNEL
#include "driver.h"
#include "device.h"
#include "irpqueue.h"
#include "trace.h"
#include <wdmsec.h>

using namespace BazisLib;
using namespace BazisLib::DDK;

//Stores the creation parameters; the DEVICE_OBJECT itself is only created later by RegisterDevice().
//If either of the internal stop/delete events could not be created, the failure is recorded via
//ReportInitializationError() so that RegisterDevice() will refuse to proceed.
Device::Device(DEVICE_TYPE DeviceType,
	   const wchar_t *pwszDeviceName,
	   ULONG DeviceCharacteristics,
	   bool bExclusive,
	   ULONG AdditionalDeviceFlags)
	: m_DeviceType(DeviceType)
	, m_pwszDeviceName(pwszDeviceName)
	, m_DeviceCharacteristics(DeviceCharacteristics)
	, m_bExclusive(bExclusive)
	, m_AdditionalDeviceFlags(AdditionalDeviceFlags)
	//Not yet registered / attached.
	, m_pDeviceObject(NULL)
	, m_pNextDevice(NULL)
	, m_pUnderlyingPDO(NULL)
	//I/O accounting state.
	, m_pDriver(NULL)
	, m_bDeletePending(false)
	, m_OutstandingIORequestCount(0)
	//Both events are constructed with an initial state of 'true' (presumably signaled: nothing outstanding).
	, m_StopAllowed(SynchronizationEvent, true)
	, m_DeleteAllowed(SynchronizationEvent, true)
	, m_bInterfaceEnabled(false)
	, m_bDestroyObjectAfterLastRequest(false)
	//No request queue or dispatcher thread until CreateDeviceRequestQueue().
	, m_pDeviceQueue(NULL)
	, m_hServiceThread(0)
	, m_pServiceThread(NULL)
	, m_InitializationStatus(STATUS_SUCCESS)
{
#ifdef _DEBUG
	m_pszDeviceDebugName = 0;
#endif
	memset(&m_InterfaceName, 0, sizeof(m_InterfaceName));

	const bool bEventsCreated = m_StopAllowed.Valid() && m_DeleteAllowed.Valid();
	if (!bEventsCreated)
		ReportInitializationError(STATUS_INSUFFICIENT_RESOURCES);
}

//Records a failed initialization step.
//Returns true (and remembers 'status') when 'status' is a failure code; returns false and changes
//nothing when it is a success code.
bool Device::ReportInitializationError(NTSTATUS status)
{
	const bool bIsFailure = !NT_SUCCESS(status);
	if (bIsFailure)
		m_InitializationStatus = status;
	return bIsFailure;
}


//Creates the DEVICE_OBJECT for this device and reports the device to the owning driver.
//The object is named "\Device\<short name>" when a short name was supplied; otherwise the I/O manager is
//asked to auto-generate a name (FILE_AUTOGENERATED_DEVICE_NAME).
//  pDriver                 - owning driver; NULL selects Driver::GetMainDriver().
//  bCompleteInitialization - clear DO_DEVICE_INITIALIZING immediately via CompleteInitialization().
//  pwszLinkPath            - optional symbolic link to create; requires a short device name.
//Returns STATUS_SUCCESS or the first failing status. On failure after the I/O accounting was armed, the
//accounting is restored and any DEVICE_OBJECT created here is deleted again, so nothing is leaked.
NTSTATUS Device::RegisterDevice(Driver *pDriver, bool bCompleteInitialization, const wchar_t *pwszLinkPath)
{
	DEBUG_TRACE(TRACE_DEVICE_REGISTRATION, ("Device::RegisterDevice() started\n"));
	if (!NT_SUCCESS(m_InitializationStatus))
	{
		DEBUG_TRACE(TRACE_DEVICE_REGISTRATION, ("Device::RegisterDevice() failing due to failed initialization (%wS)\n", MapNTStatus(m_InitializationStatus)));
		return m_InitializationStatus;
	}
	if (m_pDeviceObject)
		return STATUS_ALREADY_REGISTERED;
	if (!pDriver)
		pDriver = Driver::GetMainDriver();
	if (!pDriver)
		return STATUS_INTERNAL_ERROR;
	if (pwszLinkPath && !m_pwszDeviceName)
		return STATUS_OBJECT_NAME_INVALID;

	ASSERT(m_OutstandingIORequestCount == 0);
	ASSERT(m_StopAllowed.IsSet());
	ASSERT(m_DeleteAllowed.IsSet());

	//Take the baseline I/O reference that DeleteDevice() later releases; while it is held the count
	//cannot reach 0, so m_DeleteAllowed stays non-signaled.
	m_DeleteAllowed.Reset();
	m_OutstandingIORequestCount = 1;

	String FullDevicePath;
	if (m_pwszDeviceName)
	{
		FullDevicePath = L"\\Device\\";
		FullDevicePath += m_pwszDeviceName;
	}

	if (FullDevicePath.empty())
		m_DeviceCharacteristics |= FILE_AUTOGENERATED_DEVICE_NAME;

	NTSTATUS st = IoCreateDevice(pDriver->m_DriverObject,
								 sizeof(Extension),
								 FullDevicePath.ToNTString(),
								 m_DeviceType,
								 m_DeviceCharacteristics,
								 m_bExclusive,
								 &m_pDeviceObject);

	if (!NT_SUCCESS(st))
	{
		DEBUG_TRACE(TRACE_DEVICE_REGISTRATION, ("Device::RegisterDevice(): failed call to IoCreateDevice() (%wS)\n", MapNTStatus(st)));
		//Undo the accounting armed above so that the object is back in its unregistered state.
		m_pDeviceObject = NULL;
		m_OutstandingIORequestCount = 0;
		m_DeleteAllowed.Set();
		return st;
	}

	m_pDeviceObject->Flags |= m_AdditionalDeviceFlags;

	Extension *pExt = (Extension *)m_pDeviceObject->DeviceExtension;
	pExt->Signature = Extension::DefaultSignature;
	pExt->pDevice = this;

	if (pwszLinkPath)
	{
		m_LinkName = pwszLinkPath;
		st = IoCreateSymbolicLink(m_LinkName.ToNTString(),
								  FullDevicePath.ToNTString());
		if (!NT_SUCCESS(st))
		{
			m_LinkName.clear();
			DEBUG_TRACE(TRACE_DEVICE_REGISTRATION, ("Device::RegisterDevice(): failed call to IoCreateSymbolicLink() (%wS)\n", MapNTStatus(st)));
			//The DEVICE_OBJECT was created above but the device is not going to be registered:
			//delete it instead of leaking it, and restore the I/O accounting.
			IoDeleteDevice(m_pDeviceObject);
			m_pDeviceObject = NULL;
			m_OutstandingIORequestCount = 0;
			m_DeleteAllowed.Set();
			return st;
		}
	}

	if (bCompleteInitialization)
		CompleteInitialization();

	m_pDriver = pDriver;
	pDriver->OnDeviceRegistered(this);
	DEBUG_TRACE(TRACE_DEVICE_REGISTRATION, ("Device::RegisterDevice(): succeeded\n"));
	return STATUS_SUCCESS;
}

void Device::CompleteInitialization()
{
	if (m_pDeviceObject)
		m_pDeviceObject->Flags &= ~DO_DEVICE_INITIALIZING;
}

bool Device::SetShortDeviceName(const wchar_t *pwszNewName)
{
	if (m_pDeviceObject)
		return false;
	m_pwszDeviceName = pwszNewName;
	return true;
}

//Attaches this (still initializing) device on top of the stack that contains DeviceObject.
//On success m_pNextDevice receives the device we are attached to and m_pUnderlyingPDO remembers DeviceObject.
NTSTATUS Device::AttachToDeviceStack(PDEVICE_OBJECT DeviceObject)
{
	if (m_pDeviceObject == NULL || (m_pDeviceObject->Flags & DO_DEVICE_INITIALIZING) == 0)
		return STATUS_INVALID_DEVICE_STATE;
	if (m_pNextDevice != NULL)
		return STATUS_ALREADY_REGISTERED;

	PDEVICE_OBJECT pAttachedTo = IoAttachDeviceToDeviceStack(m_pDeviceObject, DeviceObject);
	m_pNextDevice = pAttachedTo;
	if (pAttachedTo == NULL)
		return STATUS_INVALID_DEVICE_STATE;

	m_pUnderlyingPDO = DeviceObject;
	return STATUS_SUCCESS;
}

//Attaches this device to the named device object (IoAttachDevice()).
//On success the attached-to device becomes both m_pNextDevice and m_pUnderlyingPDO.
NTSTATUS Device::AttachToDevice(String &DevicePath)
{
	if (m_pDeviceObject == NULL)
		return STATUS_INVALID_DEVICE_STATE;
	if (m_pNextDevice != NULL)
		return STATUS_ALREADY_REGISTERED;

	NTSTATUS status = IoAttachDevice(m_pDeviceObject, DevicePath.ToNTString(), &m_pNextDevice);
	if (NT_SUCCESS(status))
	{
		if (m_pNextDevice == NULL)
			status = STATUS_INVALID_DEVICE_STATE;
		else
		{
			m_pUnderlyingPDO = m_pNextDevice;
			status = STATUS_SUCCESS;
		}
	}
	return status;
}

//Registers a device interface of class *pGuid on the underlying PDO and stores its symbolic name in
//m_InterfaceName. Requires an attached device (m_pUnderlyingPDO) and allows only one interface per device.
NTSTATUS Device::RegisterInterface(IN CONST GUID *pGuid, IN PCUNICODE_STRING ReferenceString)
{
	if (pGuid == NULL)
		return STATUS_INVALID_PARAMETER;

	const bool bCanRegister = (m_pUnderlyingPDO != NULL) && (m_InterfaceName.Buffer == NULL);
	if (!bCanRegister)
		return STATUS_INVALID_DEVICE_STATE;

	//IoRegisterDeviceInterface() takes a non-const PUNICODE_STRING even though it only reads it.
	return IoRegisterDeviceInterface(m_pUnderlyingPDO, pGuid, const_cast<PUNICODE_STRING>(ReferenceString), &m_InterfaceName);
}

//Enables the interface registered by RegisterInterface(); m_bInterfaceEnabled tracks success.
NTSTATUS Device::EnableInterface()
{
	NTSTATUS status = STATUS_INVALID_DEVICE_STATE;
	if (m_InterfaceName.Buffer != NULL)
	{
		status = IoSetDeviceInterfaceState(&m_InterfaceName, TRUE);
		if (NT_SUCCESS(status))
			m_bInterfaceEnabled = true;
	}
	return status;
}

//Disables the interface registered by RegisterInterface(); m_bInterfaceEnabled is cleared on success.
NTSTATUS Device::DisableInterface()
{
	NTSTATUS status = STATUS_INVALID_DEVICE_STATE;
	if (m_InterfaceName.Buffer != NULL)
	{
		status = IoSetDeviceInterfaceState(&m_InterfaceName, FALSE);
		if (NT_SUCCESS(status))
			m_bInterfaceEnabled = false;
	}
	return status;
}

//Detaches this device from the device it was attached to by AttachToDeviceStack()/AttachToDevice().
NTSTATUS Device::DetachDevice()
{
	if (m_pDeviceObject == NULL || m_pNextDevice == NULL)
		return STATUS_INVALID_DEVICE_STATE;

	PDEVICE_OBJECT pLowerDevice = m_pNextDevice;
	m_pUnderlyingPDO = NULL;
	IoDetachDevice(pLowerDevice);
	m_pNextDevice = NULL;
	return STATUS_SUCCESS;
}

//Tears the device down.
//Steps, in order:
//  1. Disable the interface and shut down the request queue.
//  2. Release the baseline I/O reference and wait on m_DeleteAllowed, which DecrementIOCount() signals
//     when the count reaches 0.
//  3. Wait for the dispatcher thread to exit.
//  4. Remove the symbolic link, interface name and attachment, and delete the DEVICE_OBJECT.
//  5. Notify the owning driver.
//FromIRPHandler must be true when called from inside an IRP handler, so that the handler's own I/O reference
//(taken by ProcessIRP()) is dropped while waiting and re-taken afterwards.
//Returns STATUS_SUCCESS, or the failing wait/object-reference status, in which case teardown stops early.
//The destructor calls this again after an explicit call; the driver is notified at most once per
//successful RegisterDevice().
NTSTATUS Device::DeleteDevice(bool FromIRPHandler)
{
	if (m_bInterfaceEnabled)
		DisableInterface();

	m_bDeletePending = true;
	if (m_pDeviceQueue)
	{
		ASSERT(m_hServiceThread);
		m_pDeviceQueue->Shutdown();
	}

	if (m_pDeviceObject)
	{
		//Drop the baseline reference taken in RegisterDevice() (and the caller's own one, if any).
		DecrementIOCount();
		if (FromIRPHandler)
			DecrementIOCount();
		NTSTATUS st = m_DeleteAllowed.WaitEx();
		if (FromIRPHandler)
			IncrementIOCount();
		if (st != STATUS_WAIT_0)
			return st;
	}

	if (m_pDeviceQueue)
	{
		//Wait for the dispatcher thread to leave RequestDispatcherThreadBody() before freeing the queue.
		void *pThread = NULL;
		NTSTATUS st = ObReferenceObjectByHandle(m_hServiceThread, THREAD_ALL_ACCESS, *PsThreadType, KernelMode, &pThread, NULL);
		if (!NT_SUCCESS(st))
			return st;
		KeWaitForSingleObject(pThread, Executive, KernelMode, FALSE, NULL);
		ObDereferenceObject(pThread);
		ZwClose(m_hServiceThread);
		m_hServiceThread = 0;
	}

	if (m_pDeviceObject)
	{
		if (m_LinkName.size())
			IoDeleteSymbolicLink(m_LinkName.ToNTString());

		if (m_InterfaceName.Buffer)
		{
			RtlFreeUnicodeString(&m_InterfaceName);
			memset(&m_InterfaceName, 0, sizeof(m_InterfaceName));
		}
		if (m_pNextDevice)
			DetachDevice();
		IoDeleteDevice(m_pDeviceObject);
		m_pDeviceObject = NULL;
	}

	if (m_pDeviceQueue)
	{
		delete m_pDeviceQueue;
		m_pDeviceQueue = NULL;
	}

	m_bDeletePending = false;

	//m_pDriver is only set by a successful RegisterDevice(). Clear it before notifying so that a repeated
	//call (e.g. from ~Device() after an explicit DeleteDevice()) does not report the same device twice, and
	//so that no member is touched after the driver has been notified.
	if (Driver *pDriver = m_pDriver)
	{
		m_pDriver = NULL;
		pDriver->OnDeviceUnregistered(this);
	}

	return STATUS_SUCCESS;
}

//True while a DEVICE_OBJECT exists, i.e. between a successful RegisterDevice() and DeleteDevice().
bool Device::Valid()
{
	return m_pDeviceObject ? true : false;
}

//Destroys the device, tearing it down first via DeleteDevice().
Device::~Device()
{
	DeleteDevice(false);

	//DeleteDevice() frees and clears m_pDeviceQueue when it runs to completion; the queue can still be
	//present here only if DeleteDevice() returned early (failed wait or failed thread reference).
	//NOTE(review): the dispatcher thread may still be using the queue in that case - confirm this is safe.
	if (m_pDeviceQueue != NULL)
	{
		delete m_pDeviceQueue;
	}
}


//Passes the IRP to the lower driver and lets IrpCompletingCompletionRoutine() complete it once the lower
//driver has finished with it. Used by PostProcessIRP() for requests handled on the dispatcher thread.
//Returns the status from IoCallDriver()/PoCallDriver(), or STATUS_INVALID_DEVICE_REQUEST when there is
//no lower driver.
NTSTATUS Device::ForwardPacketToNextDriverWithIrpCompletion(IN IncomingIrp *Irp)
{
	if (!m_pNextDevice)
		return STATUS_INVALID_DEVICE_REQUEST;

	InterlockedOr(&Irp->m_Flags, IncomingIrp::LowerDriverCalled);

	//Mirror ForwardPacketToNextDriver(). On this path the IRP is finished by a raw IoCompleteRequest() in the
	//completion routine, never by IncomingIrp::CompleteRequest(), so the next power IRP must be started here.
	//PoStartNextPowerIrp() has to be called while the current stack location is still ours, i.e. before
	//the location is copied and the IRP is sent down.
	if (Irp->m_Flags & IncomingIrp::StartNextPowerIrp)
		PoStartNextPowerIrp(Irp->m_pIrp);

	IoCopyCurrentIrpStackLocationToNext(Irp->m_pIrp);
	IoSetCompletionRoutine(Irp->m_pIrp, IrpCompletingCompletionRoutine,
						   NULL, TRUE, TRUE, TRUE);

	if (Irp->m_Flags & IncomingIrp::IsPowerIrp)
		return PoCallDriver(m_pNextDevice, Irp->m_pIrp);
	return IoCallDriver(m_pNextDevice, Irp->m_pIrp);
}

//Sends the IRP to the lower driver and blocks until the lower driver has finished with it.
//EventSettingCompletionRoutine() returns STATUS_MORE_PROCESSING_REQUIRED, so the IRP is NOT completed
//here: the caller still owns it and must complete it.
//Returns the lower driver's status (IoStatus.Status if the call pended), or STATUS_INVALID_DEVICE_REQUEST
//when there is no lower driver. PAGED_CODE(): not callable at elevated IRQL.
NTSTATUS Device::CallNextDriverSynchronously(IN IncomingIrp *Irp)
{
	PAGED_CODE();

	if (m_pNextDevice == NULL)
		return STATUS_INVALID_DEVICE_REQUEST;

	InterlockedOr(&Irp->m_Flags, IncomingIrp::LowerDriverCalled);

	PIRP pIrp = Irp->m_pIrp;
	KEVENT completionEvent;
	KeInitializeEvent(&completionEvent, NotificationEvent, FALSE);

	IoCopyCurrentIrpStackLocationToNext(pIrp);
	IoSetCompletionRoutine(pIrp, EventSettingCompletionRoutine,
						   &completionEvent, TRUE, TRUE, TRUE);

	const bool bIsPower = (Irp->m_Flags & IncomingIrp::IsPowerIrp) != 0;
	NTSTATUS status = bIsPower ? PoCallDriver(m_pNextDevice, pIrp) : IoCallDriver(m_pNextDevice, pIrp);
	if (status != STATUS_PENDING)
		return status;

	KeWaitForSingleObject(&completionEvent, Executive, KernelMode, FALSE, NULL);
	return pIrp->IoStatus.Status;
}


//Hands the IRP to the lower driver without a completion routine (skipping our stack location).
//With no lower driver, the IRP is completed here with whatever IoStatus.Status it currently holds, and
//that status is returned.
NTSTATUS Device::ForwardPacketToNextDriver(IN IncomingIrp *Irp)
{
	PIRP pIrp = Irp->m_pIrp;

	if (m_pNextDevice == NULL)
	{
		const NTSTATUS finalStatus = pIrp->IoStatus.Status;
		Irp->CompleteRequest();
		return finalStatus;
	}

	InterlockedOr(&Irp->m_Flags, IncomingIrp::LowerDriverCalled);

	//Must precede IoSkipCurrentIrpStackLocation(): the current stack location still has to be ours.
	if (Irp->m_Flags & IncomingIrp::StartNextPowerIrp)
		PoStartNextPowerIrp(pIrp);
	IoSkipCurrentIrpStackLocation(pIrp);

	return (Irp->m_Flags & IncomingIrp::IsPowerIrp)
		? PoCallDriver(m_pNextDevice, pIrp)
		: IoCallDriver(m_pNextDevice, pIrp);
}

//Completion routine used by CallNextDriverSynchronously(): Context is the KEVENT to signal.
//Returns STATUS_MORE_PROCESSING_REQUIRED so the IRP stays with the waiting caller, who completes it.
NTSTATUS Device::EventSettingCompletionRoutine(IN PDEVICE_OBJECT  DeviceObject, IN PIRP pIrp, IN PVOID Context)
{
	UNREFERENCED_PARAMETER(DeviceObject);
	UNREFERENCED_PARAMETER(pIrp);

	PKEVENT pCompletionEvent = static_cast<PKEVENT>(Context);
	KeSetEvent(pCompletionEvent, IO_NO_INCREMENT, FALSE);

	return STATUS_MORE_PROCESSING_REQUIRED;
}

//Completion routine installed by ForwardPacketToNextDriverWithIrpCompletion(). Once the lower driver has
//completed the IRP, this routine completes it again at our level so that it continues up the stack.
//Because the IRP is re-completed here, the routine must return STATUS_MORE_PROCESSING_REQUIRED.
//Returning anything else (as the previous STATUS_SUCCESS did) lets the I/O manager carry on with the
//completion pass that invoked this routine. That pass would then process an IRP the nested
//IoCompleteRequest() has already finished, i.e. a double completion.
NTSTATUS Device::IrpCompletingCompletionRoutine(IN PDEVICE_OBJECT  DeviceObject, IN PIRP pIrp, IN PVOID Context)
{
	UNREFERENCED_PARAMETER(DeviceObject);
	UNREFERENCED_PARAMETER(Context);

	IoCompleteRequest(pIrp, IO_NO_INCREMENT);
	return STATUS_MORE_PROCESSING_REQUIRED;
}

//Entry point for an incoming IRP.
//Rejects it while a delete is pending or before registration. Otherwise it takes an I/O reference, runs
//DispatchRoutine() and lets PostProcessIRP() finish the bookkeeping. IRP_MN_QUERY_POWER / IRP_MN_SET_POWER
//power IRPs are flagged so that the next power IRP gets started when they are passed on.
NTSTATUS Device::ProcessIRP(IN PIRP Irp, bool bIsPowerIrp)
{
	if (m_bDeletePending)
		return STATUS_DELETE_PENDING;
	if (m_pDeviceObject == NULL)
		return STATUS_INVALID_DEVICE_STATE;

	IncrementIOCount();

	PIO_STACK_LOCATION pStack = IoGetCurrentIrpStackLocation(Irp);
	bool bStartsNextPowerIrp = false;
	if (bIsPowerIrp)
		bStartsNextPowerIrp = (pStack->MinorFunction == IRP_MN_QUERY_POWER) || (pStack->MinorFunction == IRP_MN_SET_POWER);

	IncomingIrp incomingIrp(Irp, bIsPowerIrp, bStartsNextPowerIrp);

	NTSTATUS dispatchStatus = DispatchRoutine(&incomingIrp, IoGetCurrentIrpStackLocation(Irp));
	return PostProcessIRP(&incomingIrp, dispatchStatus, false);
}

//Finishes the bookkeeping for an IRP after DispatchRoutine() has returned.
//Called from ProcessIRP() (FromDispatcherThread = false) and from RequestDispatcherThreadBody()
//(FromDispatcherThread = true). Returns the status to propagate to the dispatch routine's caller.
//  - STATUS_PENDING: the IRP is owned elsewhere (e.g. queued by EnqueuePacket()) and the I/O reference
//    taken in ProcessIRP() is kept.
//  - IRP already completed by our driver, or already sent to the lower driver: complete it if still
//    needed, then drop the I/O reference.
//  - Otherwise: forward it to the lower driver, then drop the I/O reference.
//NOTE: may 'delete this' when m_bDestroyObjectAfterLastRequest is set and the count has reached zero;
//the object must not be touched after this call in that case.
NTSTATUS Device::PostProcessIRP(IncomingIrp *pIrp, NTSTATUS ProcessingStatus, bool FromDispatcherThread)
{
	//A pending IRP keeps its I/O reference; whoever later finishes it is responsible for dropping it.
	//For IRPs queued by EnqueuePacket() that is the dispatcher thread's own PostProcessIRP() call.
	//NOTE(review): the original comment here also mentioned OnPendingIRPCompleted() and SendIrpToNextDevice(),
	//neither of which is visible in this file - confirm the intended contract for IRPs pended elsewhere.
	if (ProcessingStatus == STATUS_PENDING)
	{
		ASSERT(pIrp->m_Flags & IncomingIrp::IrpMarkedPending);
		return ProcessingStatus;
	}
	ASSERT(!(pIrp->m_Flags & IncomingIrp::IrpMarkedPending));

	//Forward the IRP to the lower driver unless our driver has already sent it down or completed it.
	//If the lower driver was called but the IRP is not completed yet (e.g. after CallNextDriverSynchronously(),
	//whose completion routine returns STATUS_MORE_PROCESSING_REQUIRED), complete it now.
	//NOTE(review): ForwardPacketToNextDriver() also sets LowerDriverCalled without IrpCompleted; if
	//DispatchRoutine() used it, this branch calls CompleteRequest() on an IRP now owned by the lower
	//driver - confirm IncomingIrp::CompleteRequest() guards against that.
	if (pIrp->m_Flags & (IncomingIrp::LowerDriverCalled | IncomingIrp::IrpCompleted))
	{
		if (!(pIrp->m_Flags & IncomingIrp::IrpCompleted))
			pIrp->CompleteRequest();
	}
	else
	{
		//On the dispatcher thread the IRP is forwarded with a completion routine that completes it
		//(IrpCompletingCompletionRoutine()); otherwise our stack location is simply skipped.
		if (FromDispatcherThread)
		{
			ProcessingStatus = ForwardPacketToNextDriverWithIrpCompletion(pIrp);
			/*ProcessingStatus = CallNextDriverSynchronously(pIrp);
			pIrp->CompleteRequest();*/
		}
		else
			ProcessingStatus = ForwardPacketToNextDriver(pIrp);
	}
	DecrementIOCount();

	//NOTE(review): the count is re-read after the interlocked decrement, not taken from its result.
	if (m_bDestroyObjectAfterLastRequest && !m_OutstandingIORequestCount)
		delete this;
	return ProcessingStatus;
}


//Takes one outstanding I/O reference.
//RegisterDevice() leaves the count at 1 (the baseline reference). Reaching 2 therefore means a request is
//now in flight, so stopping is no longer allowed until DecrementIOCount() brings the count back to 1.
//(The 'register' storage class was dropped: it is deprecated since C++11 and ill-formed since C++17.)
void Device::IncrementIOCount()
{
	const long result = InterlockedIncrement(&m_OutstandingIORequestCount);

	ASSERT(result > 0);

	if (result == 2)
		m_StopAllowed.Reset();
}

//Releases one outstanding I/O reference.
//  - Reaching 1: only the baseline reference from RegisterDevice() remains, so the device may be stopped.
//  - Reaching 0: the baseline reference is gone too (DeleteDevice() drops it), so deletion may proceed.
//(The 'register' storage class was dropped: it is deprecated since C++11 and ill-formed since C++17.)
void Device::DecrementIOCount()
{
	const long result = InterlockedDecrement(&m_OutstandingIORequestCount);

	ASSERT(result >= 0);

	if (result == 0)
		m_DeleteAllowed.Set();
	else if (result == 1)
		m_StopAllowed.Set();
}

//Blocks until only the baseline I/O reference is left (m_StopAllowed is signaled at count 1).
//When called from an IRP handler, that handler's own reference (taken by ProcessIRP()) is released for
//the duration of the wait; otherwise the count could never fall to 1.
NTSTATUS Device::WaitForStopEvent(bool FromIRPHandler)
{
	if (FromIRPHandler)
		DecrementIOCount();

	const NTSTATUS waitStatus = m_StopAllowed.WaitEx();

	if (FromIRPHandler)
		IncrementIOCount();

	return (waitStatus == STATUS_WAIT_0) ? STATUS_SUCCESS : waitStatus;
}

//Marks the IRP pending and queues it for RequestDispatcherThreadBody(), saving its IncomingIrp flags as
//the queue context. Returns STATUS_PENDING when queued.
//If the queue refuses the packet, the IRP is completed with STATUS_DELETE_PENDING and that status is returned.
//NOTE(review): WDM expects a dispatch routine that marked an IRP pending to return STATUS_PENDING, but
//PostProcessIRP() relies on the non-pending return here to drop the I/O reference, and it asserts that
//IrpMarkedPending is clear - confirm IncomingIrp::CompleteRequest() clears that flag.
NTSTATUS Device::EnqueuePacket(IN IncomingIrp *Irp)
{
	if (Irp == NULL)
		return STATUS_INVALID_PARAMETER;
	if (m_pDeviceQueue == NULL)
		return STATUS_INTERNAL_ERROR;

	Irp->MarkPending();

	if (m_pDeviceQueue->EnqueuePacket(Irp->m_pIrp, LongToPtr(Irp->m_Flags)))
		return STATUS_PENDING;

	Irp->SetIoStatus(STATUS_DELETE_PENDING);
	Irp->CompleteRequest();
	return STATUS_DELETE_PENDING;
}

//Body of the system thread started by CreateDeviceRequestQueue(); pParam is the owning Device.
//Dequeues IRPs until DequeuePacket() returns NULL (presumably after IrpQueue::Shutdown(), which
//DeleteDevice() calls). Each IRP is run through DispatchRoutine()/PostProcessIRP() on this thread.
void Device::RequestDispatcherThreadBody(IN PVOID pParam)
{
	Device *pDevice = static_cast<Device *>(pParam);
	if (pDevice->m_pDeviceQueue == NULL)
		return;

	pDevice->m_pServiceThread = PsGetCurrentThread();

	void *pContext = 0;
	for (;;)
	{
		PIRP pIrp = pDevice->m_pDeviceQueue->DequeuePacket(&pContext);
		if (pIrp == NULL)
			break;

		//The context holds the flags saved by EnqueuePacket(). Mark the request as coming from this thread
		//and drop IrpMarkedPending, so PostProcessIRP()'s pending checks refer to this DispatchRoutine() call.
		IncomingIrp incomingIrp(pIrp, ((PtrToLong(pContext)) | IncomingIrp::FromDispatchThread) & ~IncomingIrp::IrpMarkedPending);
		NTSTATUS dispatchStatus = pDevice->DispatchRoutine(&incomingIrp, IoGetCurrentIrpStackLocation(pIrp));
		pDevice->PostProcessIRP(&incomingIrp, dispatchStatus, true);
	}
}

//Creates the IRP queue and the system thread that services it (RequestDispatcherThreadBody()).
//On thread-creation failure the queue is freed again and the PsCreateSystemThread() status is returned.
NTSTATUS Device::CreateDeviceRequestQueue()
{
	if (m_pDeviceQueue != NULL)
		return STATUS_ALREADY_REGISTERED;

	m_pDeviceQueue = new IrpQueue();
	if (m_pDeviceQueue == NULL)
		return STATUS_NO_MEMORY;

	//Kernel handle so that the thread handle is not accessible from user mode.
	OBJECT_ATTRIBUTES attributes;
	InitializeObjectAttributes(&attributes, NULL, OBJ_KERNEL_HANDLE, 0, NULL);

	const NTSTATUS status = PsCreateSystemThread(&m_hServiceThread, THREAD_ALL_ACCESS, &attributes, 0, NULL, RequestDispatcherThreadBody, this);
	if (NT_SUCCESS(status))
		return STATUS_SUCCESS;

	delete m_pDeviceQueue;
	m_pDeviceQueue = NULL;
	return status;
}

#include "../../bzscore/WinKernel/security.h"
#include "../../bzscore/file.h"

using namespace BazisLib::DDK::Security;

//Applies pACL as the DACL of the device interface registered by RegisterInterface().
NTSTATUS Device::ApplyDACL(TranslatedAcl *pACL)
{
	PCWSTR pwszInterfacePath = m_InterfaceName.Buffer;
	if (pwszInterfacePath == NULL)
		return STATUS_INVALID_DEVICE_STATE;
	return sApplyDACL(pACL, pwszInterfacePath);
}

//Replaces the DACL of the object at pLinkPath with *pACL.
//Steps:
//  1. Open the object and query its current DACL; the first query obtains the required size.
//  2. Substitute the DACL in a translated copy of the descriptor.
//  3. Write the descriptor back.
//Returns the status of the failing step, or of ZwSetSecurityObject().
NTSTATUS BazisLib::DDK::Device::sApplyDACL(class Security::TranslatedAcl *pACL, PCWSTR pLinkPath)
{
	if (pACL == NULL)
		return STATUS_INVALID_PARAMETER;

	TypedBuffer<PSECURITY_DESCRIPTOR> pDescriptor;
	ActionStatus st;
	File f(pLinkPath, FileModes::OpenReadWrite, &st);
	HANDLE hObject = f.DetachHandle();
	if (hObject == NULL)
		return st.ConvertToNTStatus();

	ULONG requiredSize = 0;
	NTSTATUS status = ZwQuerySecurityObject(hObject, DACL_SECURITY_INFORMATION, NULL, 0, &requiredSize);
	if (requiredSize != 0)
	{
		pDescriptor.EnsureSizeAndSetUsed(requiredSize);
		status = ZwQuerySecurityObject(hObject, DACL_SECURITY_INFORMATION, pDescriptor, (ULONG)pDescriptor.GetAllocated(), &requiredSize);
		if (NT_SUCCESS(status))
		{
			Security::TranslatedSecurityDescriptor desc(pDescriptor);
			desc.DACL = *pACL;
			status = ZwSetSecurityObject(hObject, DACL_SECURITY_INFORMATION, desc.BuildNTSecurityDescriptor());
		}
	}

	//Single exit: the detached handle is always closed exactly once.
	ZwClose(hObject);
	return status;
}
#endif